{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Calling BertTokenizer.from_pretrained() with the path to a single file or url is deprecated\n"
     ]
    }
   ],
   "source": [
    "from transformers import BertTokenizer\n",
    "tokenizer = BertTokenizer.from_pretrained('/Users/wangjianyu/Downloads/Bert-base/bert-base-uncased-vocab.txt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "sequence = \"A Titan RTX has 24GB of VRAM\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenized_sequence = tokenizer.tokenize(sequence)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['a', 'titan', 'rt', '##x', 'has', '24', '##gb', 'of', 'vr', '##am']"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenized_sequence # 这里就是按照Vocab.txt,用WordPiece对句子进行分词"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "inputs = tokenizer.encode_plus(sequence, add_special_tokens = True, max_length=20, truncation=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'input_ids': [101, 1037, 16537, 19387, 2595, 2038, 2484, 18259, 1997, 27830, 3286, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inputs # 此时分词工具会返回所有与该模型能够正常工作有关的参数字典 "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "encoded_sequence = inputs['input_ids']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[101, 1037, 16537, 19387, 2595, 2038, 2484, 18259, 1997, 27830, 3286, 102]"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "encoded_sequence # 这里因为我设置了add_special_tokens=True, 所以会添加CLS 等特殊字符， 返回的是tokenize对应的词的id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# if we decode the previous sequence of ids\n",
    "decoded_sequence = tokenizer.decode(encoded_sequence)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'[CLS] a titan rtx has 24gb of vram [SEP]'"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "decoded_sequence # return the raw sequence"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "## 关于可选参数Attention Mask的意义\n",
    "## 这个参数告诉模型什么tokens应该被attend to, 什么不应该"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "sequence_a = \"This is a short sequence.\"\n",
    "sequence_b = \"This is a rather long sequence. It is at least longer than the sequence A.\"\n",
    "\n",
    "encoded_sequence_a = tokenizer.encode_plus(sequence_a, add_special_tokens = True, max_length=50, truncation=True)['input_ids']\n",
    "encoded_sequence_b = tokenizer.encode_plus(sequence_b, add_special_tokens = True, max_length=50, truncation=True)['input_ids']\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(8, 19)"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(encoded_sequence_a), len(encoded_sequence_b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 因为两个有不同的长度，所以我们不能直接混合成一个tensor，第一个sequence必须进行padding，或者第二个要进行截断\n",
    "padded_sequences = tokenizer.encode_plus(sequence_a,sequence_b,pad_to_max_length=True, add_special_tokens =True, max_length=30)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'input_ids': [101, 2023, 2003, 1037, 2460, 5537, 1012, 102, 2023, 2003, 1037, 2738, 2146, 5537, 1012, 2009, 2003, 2012, 2560, 2936, 2084, 1996, 5537, 1037, 1012, 102, 0, 0, 0, 0], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0]}"
      ]
     },
     "execution_count": 64,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "padded_sequences # 这个就是一般dataset的getitem都这么写，就是借助pad_to_max_length这个参数来补全到设置好的最大的长度"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 可以看出Attention Mask就是说明了哪些属于真实value，哪些属于padding\n",
    "decoded = tokenizer.decode(padded_sequences['input_ids'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'[CLS] this is a short sequence. [SEP] this is a rather long sequence. it is at least longer than the sequence a. [SEP] [PAD] [PAD] [PAD] [PAD]'"
      ]
     },
     "execution_count": 66,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "decoded # 这里可以看出使用两句话一起concat的时候，就是通过SEP进行分割"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 1,\n",
       " 0,\n",
       " 0,\n",
       " 0,\n",
       " 0]"
      ]
     },
     "execution_count": 67,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "padded_sequences['token_type_ids'] # 这里就是说清楚每句话从属部分"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 一般来说使用Bert Embedding作为原始的Word Embedding的话\n",
    "# 1.这个就是dataset一般的设置方式@Word Embedding必须输入这些@BertModel\n",
    "class CustomDataset(Dataset):\n",
    "\n",
    "    def __init__(self, dataframe, tokenizer, max_len):\n",
    "        self.tokenizer = tokenizer\n",
    "        self.data = dataframe\n",
    "        self.comment_text = self.data.text\n",
    "        self.targets = self.data.label\n",
    "        self.max_len = max_len\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.comment_text)\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        comment_text = self.comment_text[index]\n",
    "\n",
    "        inputs = self.tokenizer.encode_plus(\n",
    "            comment_text,\n",
    "            truncation=True,\n",
    "            add_special_tokens=True,\n",
    "            max_length=self.max_len,\n",
    "            pad_to_max_length=True,\n",
    "            return_token_type_ids=True\n",
    "        )\n",
    "        ids = inputs['input_ids']\n",
    "        mask = inputs['attention_mask']\n",
    "        token_type_ids = inputs[\"token_type_ids\"]\n",
    "\n",
    "        return {\n",
    "            'ids': torch.tensor(ids, dtype=torch.long),\n",
    "            'mask': torch.tensor(mask, dtype=torch.long),\n",
    "            'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long),\n",
    "            'targets': torch.tensor(self.targets[index], dtype=torch.float)\n",
    "        }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 2.Creating the dataset and dataloader for the neural network\n",
    "MAX_LEN = 256 ## 这里的Max Length就是encode_plus的最大长度, 不过呢由于WordPiece的存在，一般设置比较大，都不会影响的, 约等于max_position_embeddings\n",
    "train_size = 0.8\n",
    "train_dataset = train_df.sample(frac=train_size,random_state=7)\n",
    "valid_dataset = train_df.drop(train_dataset.index).reset_index(drop=True)\n",
    "train_dataset = train_dataset.reset_index(drop=True)\n",
    "\n",
    "\n",
    "print(\"FULL Dataset: {}\".format(train_df.shape))\n",
    "print(\"TRAIN Dataset: {}\".format(train_dataset.shape))\n",
    "print(\"VALID Dataset: {}\".format(valid_dataset.shape))\n",
    "print(\"TEST Dataset: {}\".format(test_df.shape))\n",
    "\n",
    "train_set = CustomDataset(train_dataset, tokenizer, MAX_LEN)\n",
    "valid_set = CustomDataset(valid_dataset, tokenizer, MAX_LEN)\n",
    "test_set = CustomDataset(test_df, tokenizer, MAX_LEN)\n",
    "\n",
    "\n",
    "TRAIN_BATCH_SIZE = 32\n",
    "VALID_BATCH_SIZE = 16\n",
    "TEST_BATCH_SIZE = 16\n",
    "train_params = {'batch_size': TRAIN_BATCH_SIZE,\n",
    "                'shuffle': True}\n",
    "\n",
    "valid_params = {'batch_size': VALID_BATCH_SIZE,\n",
    "                'shuffle': True}\n",
    "\n",
    "test_params = {'batch_size': TEST_BATCH_SIZE,\n",
    "                'shuffle': False}\n",
    "\n",
    "train_loader = DataLoader(train_set, **train_params)\n",
    "valid_loader = DataLoader(valid_set, **valid_params)\n",
    "test_loader = DataLoader(test_set, **test_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 3. Bert Model\n",
    "class BERTClass(torch.nn.Module):\n",
    "    def __init__(self):\n",
    "        super(BERTClass, self).__init__()\n",
    "        self.config = BertConfig.from_pretrained('../emb/bert-mini/bert_config.json', output_hidden_states=True)\n",
    "        self.l1 = BertModel.from_pretrained('../emb/bert-mini/pytorch_model.bin', config=self.config)\n",
    "        self.bilstm1 = torch.nn.LSTM(512, 64, 1, bidirectional=True)\n",
    "        self.l2 = torch.nn.Linear(128, 64)\n",
    "        self.a1 = torch.nn.ReLU()\n",
    "        self.l3 = torch.nn.Dropout(0.3)\n",
    "        self.l4 = torch.nn.Linear(64, 14)\n",
    "    \n",
    "    def forward(self, ids, mask, token_type_ids):\n",
    "        sequence_output, pooler_output, hidden_states= self.l1(ids, attention_mask=mask, token_type_ids=token_type_ids)\n",
    "        # sequence_output: sequence of hidden_states at the output of the last layer of the model (batchsize, seq_len, hidden_size)\n",
    "        # pooler_output: last layer hidden-state of the first token of the sequence, (batch_size, hidden_size)\n",
    "        # hidden_states: (layers, batch_size, seq_len, hidden_size), each layer of the output\n",
    "        bs = len(sequence_output) \n",
    "        h12 = hidden_states[-1][:,0].view(1,bs,256)\n",
    "        h11 = hidden_states[-2][:,0].view(1,bs,256)\n",
    "        concat_hidden = torch.cat((h12,h11),2)\n",
    "        x, _ = self.bilstm1(concat_hidden)\n",
    "        x = self.l2(x.view(bs,128))\n",
    "        x = self.a1(x)\n",
    "        x = self.l3(x)\n",
    "        output = self.l4(x)\n",
    "        return output\n",
    "\n",
    "net = BERTClass()\n",
    "net.to(device)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
